library(panelView)
library(zoo)
library(synthdid)
library(fect)
library(ggplot2)
library(data.table)
library(parallel)
library(doParallel)

# Mirror everything printed from here on into power_results.txt while
# still echoing it to the console (split = TRUE), then source the
# companion script staggered_synth.R.
sink(file = "power_results.txt", split = TRUE)
source(file = "staggered_synth.R")

# Read the law grid (one row per state, one column per year) and the
# turnout panel. In the grid, a cell containing "X" marks Law = TRUE for
# that state-year; anything else (blank, NA, other text) is Law = FALSE.
laws    <- read.csv("statelaws.csv")
turnout <- read.csv("stateturnout.csv")

# Drop rows without a usable State (empty string or missing).
laws <- laws[!is.na(laws$State) & laws$State != "", ]

# Reshape wide -> long in one vectorised step rather than growing a list
# cell by cell (which is quadratic and breaks on empty inputs).
# Every column after the first is treated as a year column. read.csv()
# makes numeric headers syntactic (e.g. "2006" -> "X2006"), so the first
# character is stripped to recover the year text.
year_cols <- colnames(laws)[-1]
cells     <- as.matrix(laws[, -1, drop = FALSE])
has_law   <- !is.na(cells) & cells == "X"

laws <- data.frame(
  STATE = rep(laws$State, each = length(year_cols)),
  YEAR  = rep(substring(year_cols, 2, 10000), times = nrow(laws)),
  # t() makes the flags run state-major, matching STATE/YEAR above
  Law   = as.vector(t(has_law))
)

# Join the long law panel to turnout by state and year, then keep
# non-Texas rows for years after 1982 with YEAR %% 4 == 2 (the US midterm
# cycle).
data <- merge(laws, turnout, by = c("STATE", "YEAR"))

# YEAR comes out of the reshape as text. Convert it once so that this
# filter and later `Year >= <adoption year>` comparisons are numeric
# rather than lexicographic string comparisons.
data$YEAR <- as.numeric(data$YEAR)
keep_rows <- !(data$STATE %in% "Texas") &
  !is.na(data$YEAR) &
  data$YEAR > 1982 &
  data$YEAR %% 4 == 2
data <- data[keep_rows, ]

# Turnout as a percentage: votes for the highest office divided by VEP.
# Both columns may contain thousands separators, which are stripped.
data$Y <- as.numeric(gsub(",", "", data$VOTE_FOR_HIGHEST_OFFICE)) /
  as.numeric(gsub(",", "", data$VEP)) * 100

names(data)[names(data) == "Law"]   <- "W"
names(data)[names(data) == "STATE"] <- "State"
names(data)[names(data) == "YEAR"]  <- "Year"

# Treatment indicator as 0/1 instead of FALSE/TRUE.
data$W <- as.numeric(data$W)


# Hand-coded adopter groups: group 1 and group 2 are the listed states,
# group 0 is every other state. Only group-0 states form the pool that
# run_simulation() draws placebo adopters from.
early_adopter_states <- c("Arizona", "Georgia", "Indiana", "Ohio")
late_adopter_states  <- c("Alabama", "Kansas", "Mississippi", "North Dakota",
                          "Tennessee", "Virginia", "Wisconsin")
data$group <- as.numeric(data$State %in% early_adopter_states) +
  2 * as.numeric(data$State %in% late_adopter_states)

data.sim <- data[data$group %in% 0, ]

# Leave one core free for the master, but always request at least one
# worker. detectCores() - 1 is 0 on a single-core machine, and
# detectCores() returns NA when the core count cannot be determined;
# either value would make makeCluster() fail.
n_cores <- max(1L, detectCores() - 1L, na.rm = TRUE)
cl      <- makeCluster(n_cores)
registerDoParallel(cl)

states <- unique(data.sim$State)
clusterExport(cl, c("data.sim", "states"))

# One Monte Carlo replicate of the placebo power analysis.
#
# Draws staggered "adopter" cohorts at random from `states` and switches
# their treatment indicator W2 on from each cohort's adoption year onward.
# It then subtracts a simulated turnout effect from treated state-years,
# fits a fect matrix-completion model, and reports the overall ATT plus
# the group-level ATTs for group2 == 1 ("early") and group2 == 2 ("late").
#
# Args:
#   iter         Replicate index. It seeds R's RNG here and is passed to
#                fect() as its seed.
#   data.sim     Panel with columns State, Year, Y and VEP. VEP may contain
#                thousands separators.
#   states       Character vector of states eligible to be drawn.
#   cohorts      Named vector: names are adoption years, values are how
#                many states adopt in that year. Cohorts are drawn in the
#                order given, without replacement across cohorts.
#   effect_prob  Binomial probability used to simulate the turnout
#                reduction (0.03 is roughly a 3-point drop in Y).
#   late_start   Cohorts adopting in or after this year get group2 == 2.
#                Earlier cohorts get group2 == 1.
#   alpha        Significance threshold for the *_sig flags.
#
# Returns a list with att_est, att_sig, early_est, early_sig, late_est,
# late_sig and error (NA on success). If fect() or the result extraction
# fails, estimates and flags are NA and `error` holds the condition
# message, so one bad replicate does not abort the whole parallel run.
run_simulation <- function(iter, data.sim, states,
                           cohorts = c("2006" = 3L, "2008" = 1L,
                                       "2012" = 2L, "2014" = 2L,
                                       "2016" = 1L),
                           effect_prob = 0.03,
                           late_start = 2012,
                           alpha = 0.05) {
  set.seed(iter)
  # Cluster workers may not have fect attached, so attach it here.
  library(fect)

  # NOTE(review): the original notes said the 2012 and 2014 cohorts should
  # hold 3 states each for midterm-only panels (2 for presidential or all
  # elections). The defaults keep the original 2/2. Confirm this against
  # the panel actually passed in.
  adopt_years <- as.numeric(names(cohorts))
  # Year may arrive as text; compare numerically, not lexicographically.
  years <- as.numeric(data.sim$Year)

  # Draw cohorts sequentially from the states not yet drawn. This consumes
  # the RNG in the same order as drawing each cohort by hand.
  remaining <- states
  adopters  <- vector("list", length(cohorts))
  for (k in seq_along(cohorts)) {
    adopters[[k]] <- sample(remaining, cohorts[[k]])
    remaining     <- setdiff(remaining, adopters[[k]])
  }

  data.sim$W2     <- 0L
  data.sim$group2 <- 0L
  for (k in seq_along(cohorts)) {
    in_cohort <- data.sim$State %in% adopters[[k]]
    data.sim$W2[in_cohort & years >= adopt_years[k]] <- 1L
    data.sim$group2[in_cohort] <- if (adopt_years[k] < late_start) 1L else 2L
  }

  # Simulated effect: remove Binomial(VEP, effect_prob) voters from each
  # treated state-year, expressed in percentage points of VEP.
  treated <- data.sim$W2 == 1L
  vep     <- as.numeric(gsub(",", "", data.sim$VEP[treated]))
  data.sim$Y2 <- data.sim$Y
  data.sim$Y2[treated] <- data.sim$Y[treated] -
    # rbinom() needs an integer number of trials; otherwise it yields NaN.
    rbinom(sum(treated), size = round(vep), prob = effect_prob) / vep * 100

  failed <- function(msg) {
    list(
      att_est   = NA_real_, att_sig   = NA,
      early_est = NA_real_, early_sig = NA,
      late_est  = NA_real_, late_sig  = NA,
      error     = msg
    )
  }

  tryCatch({
    fit <- fect(
      Y2       ~ W2,
      data     = data.sim,
      index    = c("State", "Year"),
      method   = "mc",
      se       = TRUE,
      group    = "group2",
      CV       = TRUE,
      force    = "two-way",
      parallel = FALSE,
      seed     = iter,
      vartype  = "jackknife"
    )

    avg  <- fit$est.avg[1, ]
    gatt <- fit$est.group.att
    # NOTE(review): rows 2 and 3 are assumed to be group2 == 1 (early) and
    # group2 == 2 (late); confirm against fect's est.group.att layout.
    early <- gatt[2, ]
    late  <- gatt[3, ]

    # [[ ]] drops the element names; a missing p-value stays NA.
    list(
      att_est   = as.numeric(avg[["ATT.avg"]]),
      att_sig   = avg[["p.value"]] < alpha,
      early_est = as.numeric(early[["ATT"]]),
      early_sig = early[["p.value"]] < alpha,
      late_est  = as.numeric(late[["ATT"]]),
      late_sig  = late[["p.value"]] < alpha,
      error     = NA_character_
    )
  }, error = function(e) failed(conditionMessage(e)))
}

niter <- 1000
# Always release the worker processes, even if parLapply() fails.
results <- tryCatch(
  parLapply(cl, seq_len(niter), run_simulation,
            data.sim = data.sim, states = states),
  finally = stopCluster(cl)
)

# Pull one scalar field out of every replicate with a fixed result type.
# vapply() errors loudly if a replicate returns something else.
extract_field <- function(field, template) {
  vapply(results, `[[`, template, field)
}

att_est   <- extract_field("att_est", numeric(1))
att_sig   <- extract_field("att_sig", logical(1))
early_est <- extract_field("early_est", numeric(1))
early_sig <- extract_field("early_sig", logical(1))
late_est  <- extract_field("late_est", numeric(1))
late_sig  <- extract_field("late_sig", logical(1))

# NA flags (failed fits or missing p-values) count as non-rejections.
cat("Overall ATT  sig @5%:", sum(att_sig, na.rm = TRUE),  "of", niter, "\n")
cat("Early-adopt sig @5%:", sum(early_sig, na.rm = TRUE), "of", niter, "\n")
cat("Late-adopt  sig @5%:", sum(late_sig, na.rm = TRUE),  "of", niter, "\n")

n_failed <- sum(is.na(att_sig))
if (n_failed > 0) {
  cat("Replicates without an overall p-value:", n_failed, "of", niter, "\n")
}

# Mean estimate among replicates that rejected at 5%. which() skips NAs.
if (any(att_sig, na.rm = TRUE)) {
  cat("Mean ATT  (signif.):", mean(att_est[which(att_sig)]), "\n")
}
if (any(early_sig, na.rm = TRUE)) {
  cat("Mean Early (signif.):", mean(early_est[which(early_sig)]), "\n")
}
if (any(late_sig, na.rm = TRUE)) {
  cat("Mean Late  (signif.):", mean(late_est[which(late_sig)]), "\n")
}

sink()
